# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains TPU-specific Pallas abstractions."""
from __future__ import annotations

import collections
from collections.abc import Sequence
import dataclasses
import enum
from typing import Any, ClassVar, Literal
from collections.abc import Mapping

import jax
import jax.numpy as jnp
from jax.extend import backend as jex_backend
from jax._src import core as jax_core
from jax._src import state
from jax._src import util
from jax._src.frozen_dict import FrozenDict
from jax._src.pallas import core as pallas_core
import numpy as np


# Shadow the builtin map/zip with length-checked variants; the originals stay
# available under `unsafe_` names for call sites that deliberately skip the
# length check.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

# Re-exported aliases from the backend-agnostic Pallas core.
no_block_spec = pallas_core.no_block_spec
_out_shape_to_aval_mapping = pallas_core._out_shape_to_aval_mapping


class KernelType(enum.Enum):
  """Which TPU core a Pallas kernel is meant to run on.

  See `CompilerParams.kernel_type`: either the TensorCore or one of the
  SparseCore subcores.
  """
  # TensorCore kernel (the default).
  TC = 0
  # SparseCore kernel on the scalar subcore.
  SC_SCALAR_SUBCORE = 1
  # SparseCore kernel on the vector subcore.
  SC_VECTOR_SUBCORE = 2


class GridDimensionSemantics(enum.Enum):
  """Scheduling semantics for a single grid dimension of a kernel."""
  # Iterations along this dimension may execute in any order.
  PARALLEL = "parallel"
  # NOTE(review): names suggest parallelism across cores / subcores
  # respectively — confirm exact semantics against the Mosaic lowering.
  CORE_PARALLEL = "core_parallel"
  SUBCORE_PARALLEL = "subcore_parallel"
  # Iterations along this dimension must execute sequentially.
  ARBITRARY = "arbitrary"

# Module-level shorthands for the enum members.
PARALLEL = GridDimensionSemantics.PARALLEL
CORE_PARALLEL = GridDimensionSemantics.CORE_PARALLEL
SUBCORE_PARALLEL = GridDimensionSemantics.SUBCORE_PARALLEL
ARBITRARY = GridDimensionSemantics.ARBITRARY


# Type alias: dimension semantics may be given as an enum member or as the
# equivalent string literal.
DimensionSemantics = (
    Literal["parallel", "core_parallel", "subcore_parallel", "arbitrary"]
    | GridDimensionSemantics
)


class SideEffectType(enum.Enum):
  """How side-effecting a kernel is, i.e. whether it may be CSEd/removed."""
  # No side effects, can be deduplicated / removed if unused.
  PURE = "pure"
  # Cannot be deduplicated, but can be removed if unused.
  DATAFLOW_SIDE_EFFECTING = "dataflow_side_effecting"
  # Cannot be deduplicated or removed.
  SIDE_EFFECTING = "side_effecting"


@dataclasses.dataclass(frozen=True)
class CompilerParams(pallas_core.CompilerParams):
  """Mosaic TPU compiler parameters.

  Attributes:
    dimension_semantics: A list of dimension semantics for each grid dimension
      of the kernel. Either "parallel" for dimensions that can execute in any
      order, or "arbitrary" for dimensions that must be executed sequentially.
    allow_input_fusion: A list of booleans indicating whether input fusion is
      allowed for each argument.
    vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note that
      this must be used in conjunction with the
      --xla_tpu_scoped_vmem_limit_kib=N flag with N*1kib > vmem_limit_bytes.
    collective_id: Indicates which barrier semaphore to use for the kernel. Note
      that using the same collective_id does not guarantee that the same barrier
      semaphore will be allocated between kernels.
    has_side_effects: Set to True to prevent kernel being CSEd by XLA. May also
      be a `SideEffectType` for finer-grained control.
    flags: A dictionary of command line flags for the kernel.
    internal_scratch_in_bytes: The size of the internal scratch space used by
      Mosaic.
    serialization_format: The serialization format for the kernel body.
    kernel_type: Specify if the kernel is meant to run on TensorCore or one of
      the SparseCores
    disable_bounds_checks: Disable bounds checks in the kernel.
    skip_device_barrier: Skip the default device barrier for the kernel.
    allow_collective_id_without_custom_barrier: Allow the use of collective_id
      without a custom barrier.
    shape_invariant_numerics: NOTE(review): presumably requests numerics that
      do not depend on block/grid shapes — confirm exact semantics against the
      Mosaic lowering. Defaults to True.
  """
  BACKEND: ClassVar[pallas_core.Backend] = "mosaic_tpu"
  dimension_semantics: tuple[DimensionSemantics, ...] | None = None
  allow_input_fusion: tuple[bool, ...] | None = None
  vmem_limit_bytes: int | None = None
  collective_id: int | None = None
  has_side_effects: bool | SideEffectType = False
  flags: dict[str, Any] | None = None
  internal_scratch_in_bytes: int | None = None
  serialization_format: int = 1
  kernel_type: KernelType = KernelType.TC
  disable_bounds_checks: bool = False
  skip_device_barrier: bool = False
  allow_collective_id_without_custom_barrier: bool = False
  shape_invariant_numerics: bool = True

  # Hand-written __init__ (instead of the generated one) so that mutable
  # Sequence/Mapping arguments are normalized to immutable tuple/FrozenDict
  # values. Because the dataclass is frozen, fields must be installed via
  # object.__setattr__.
  def __init__(
      self,
      dimension_semantics: Sequence[DimensionSemantics] | None = None,
      allow_input_fusion: Sequence[bool] | None = None,
      vmem_limit_bytes: int | None = None,
      collective_id: int | None = None,
      has_side_effects: bool | SideEffectType = False,
      flags: Mapping[str, Any] | None = None,
      internal_scratch_in_bytes: int | None = None,
      serialization_format: int = 1,
      kernel_type: KernelType = KernelType.TC,
      disable_bounds_checks: bool = False,
      skip_device_barrier: bool = False,
      allow_collective_id_without_custom_barrier: bool = False,
      shape_invariant_numerics: bool = True,
  ):
    object.__setattr__(
        self,
        "dimension_semantics",
        None if dimension_semantics is None else tuple(dimension_semantics),
    )
    object.__setattr__(
        self,
        "allow_input_fusion",
        None if allow_input_fusion is None else tuple(allow_input_fusion),
    )
    object.__setattr__(self, "vmem_limit_bytes", vmem_limit_bytes)
    object.__setattr__(self, "collective_id", collective_id)
    object.__setattr__(self, "has_side_effects", has_side_effects)
    object.__setattr__(
        self, "flags", None if flags is None else FrozenDict(flags)
    )
    object.__setattr__(
        self, "internal_scratch_in_bytes", internal_scratch_in_bytes
    )
    object.__setattr__(self, "serialization_format", serialization_format)
    object.__setattr__(self, "kernel_type", kernel_type)
    object.__setattr__(self, "disable_bounds_checks", disable_bounds_checks)
    object.__setattr__(self, "skip_device_barrier", skip_device_barrier)
    object.__setattr__(
        self,
        "allow_collective_id_without_custom_barrier",
        allow_collective_id_without_custom_barrier,
    )
    object.__setattr__(
        self, "shape_invariant_numerics", shape_invariant_numerics
    )

  # Replace is a method, not a field.
  replace = dataclasses.replace

class MemorySpace(enum.Enum):
  """TPU memory spaces in which Pallas refs can be placed.

  An instance is callable — `MemorySpace.VMEM(shape, dtype)` builds a
  `MemoryRef` for a ShapedArray in that space.
  """
  ANY = "any"  # TODO(b/368401328): Remove this and just use pl.ANY.
  VMEM = "vmem"
  VMEM_SHARED = "vmem_shared"
  SMEM = "smem"
  CMEM = "cmem"
  SEMAPHORE = "semaphore_mem"
  HBM = "hbm"
  HOST = "host"

  def __str__(self) -> str:
    # Print the raw string value (e.g. "vmem"), not the enum repr.
    return self.value

  def from_type(self, ty):
    # Wraps an existing aval `ty` in a MemoryRef placed in this memory space.
    return pallas_core.MemoryRef(ty, memory_space=self)

  def __call__(self, shape: Sequence[int], dtype: jnp.dtype):
    # A convenience function for constructing MemoryRef types of ShapedArrays.
    return self.from_type(jax_core.ShapedArray(tuple(shape), dtype))

# Marker dtype class distinguishing DMA semaphores from regular semaphores.
class dma_semaphore(pallas_core.semaphore_dtype): pass

class DMASemaphore(pallas_core.AbstractSemaphoreTy):
  """Abstract semaphore type whose element dtype is `dma_semaphore`."""
  type = dma_semaphore
  name = "dma_sem"

class SemaphoreType(enum.Enum):
  """The flavors of semaphore available to a Pallas TPU kernel.

  An instance is callable — `SemaphoreType.DMA(shape)` builds a `MemoryRef`
  for an array of semaphores of that flavor in semaphore memory.
  """
  REGULAR = "regular"
  DMA = "dma"
  BARRIER = "barrier"

  def __call__(self, shape: tuple[int, ...]):
    """Returns a MemoryRef for a `shape`-d array of this semaphore flavor."""
    # Pick the dtype matching this flavor; REGULAR is the fallback.
    sem_dtype: Any
    if self is SemaphoreType.DMA:
      sem_dtype = DMASemaphore()
    elif self is SemaphoreType.BARRIER:
      sem_dtype = pallas_core.BarrierSemaphore()
    else:
      sem_dtype = pallas_core.Semaphore()
    aval = jax_core.ShapedArray(shape, sem_dtype)
    return pallas_core.MemoryRef(aval, MemorySpace.SEMAPHORE)

  def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace:
    """Returns the array aval of a single (scalar) semaphore of this type."""
    scalar_ref = self(())
    return scalar_ref.get_array_aval()

  def get_ref_aval(self) -> state.AbstractRef:
    """Returns the ref aval of a single (scalar) semaphore of this type."""
    scalar_ref = self(())
    return scalar_ref.get_ref_aval()

@dataclasses.dataclass(frozen=True)
class AbstractSemaphore(jax_core.AbstractValue):
  """Abstract value representing a semaphore of a particular flavor."""
  # The semaphore flavor (REGULAR / DMA / BARRIER).
  sem_type: SemaphoreType


@dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True)
class PrefetchScalarGridSpec(pallas_core.GridSpec):
  """A GridSpec that additionally carries a scalar-prefetch argument count.

  The first `num_scalar_prefetch` kernel operands are treated as scalars and
  get SMEM refs (see `_make_scalar_ref_aval`).
  """
  num_scalar_prefetch: int

  def __init__(
      self,
      num_scalar_prefetch: int,
      grid: pallas_core.Grid = (),
      in_specs: pallas_core.BlockSpecTree = no_block_spec,
      out_specs: pallas_core.BlockSpecTree = no_block_spec,
      scratch_shapes: pallas_core.ScratchShapeTree = ()
  ):
    super().__init__(grid, in_specs, out_specs, scratch_shapes)
    self.num_scalar_prefetch = num_scalar_prefetch
    # Re-store as a tuple so the attribute is immutable (and hashable, as
    # required by unsafe_hash=True).
    self.scratch_shapes = tuple(scratch_shapes)

  def _make_scalar_ref_aval(self, aval):
    # Scalar-prefetch operands live in SMEM.
    return state.AbstractRef(jax_core.ShapedArray(aval.shape, aval.dtype),
                             MemorySpace.SMEM)


@dataclasses.dataclass(frozen=True)
class TensorCore:
  """Identifies a single TensorCore by its index."""
  id: int


@dataclasses.dataclass(frozen=True)
class TensorCoreMesh:
  """A mesh of TensorCores."""
  devices: np.ndarray
  axis_names: Sequence[str]

  def __init__(self, devices: np.ndarray, axis_names: Sequence[str]):
    # Take a defensive, read-only copy of the device array. Fields are
    # installed with object.__setattr__ because the dataclass is frozen.
    frozen_devices = np.copy(devices)
    frozen_devices.setflags(write=False)
    object.__setattr__(self, "devices", frozen_devices)
    object.__setattr__(self, "axis_names", tuple(axis_names))

  def __hash__(self) -> int:
    # ndarrays are unhashable, so hash on (shape, flattened contents) instead.
    flat_devices = tuple(self.devices.ravel())
    return hash((self.devices.shape, flat_devices, self.axis_names))

  @property
  def backend(self) -> str:
    return "mosaic_tpu"

  @property
  def shape(self):
    axis_sizes = zip(self.axis_names, self.devices.shape)
    return collections.OrderedDict(axis_sizes)

  def discharges_effect(self, effect: jax_core.Effect):
    # This mesh discharges no effects.
    del effect
    return False


def create_tensorcore_mesh(
    axis_name: str,
    devices: Sequence[jax.Device] | None = None,
    num_cores: int | None = None,
) -> TensorCoreMesh:
  """Builds a 1D TensorCoreMesh with one entry per core.

  Args:
    axis_name: The name of the mesh's single axis.
    devices: Optional devices whose first element determines the core count.
    num_cores: Optional explicit core count. Mutually exclusive with `devices`.

  Returns:
    A TensorCoreMesh over `num_cores` TensorCores.
  """
  if devices is not None and num_cores is not None:
    raise ValueError('cannot specify both devices and num_cores')
  if num_cores is None:
    if devices is None:
      # Prefer the abstract device of the current abstract mesh; fall back to
      # the first concrete device.
      abstract_device = jax.sharding.get_abstract_mesh().abstract_device
      if abstract_device is None:
        devices = [jax.devices()[0]]
      else:
        devices = [abstract_device]
    num_cores = devices[0].num_cores
  cores = [TensorCore(i) for i in range(num_cores)]
  return TensorCoreMesh(np.array(cores), [axis_name])


def _tensorcore_mesh_discharge_rule(
    in_avals,
    out_avals,
    *args,
    mesh,
    jaxpr,
    compiler_params: Any | None,
    interpret: Any,
    debug: bool,
    cost_estimate: pallas_core.CostEstimate | None,
    name: str,
    metadata: FrozenDict[str, str] | None,
):
  """Discharge rule for core_map over a TensorCoreMesh.

  Validates the compiler params and the mesh, then defers to the generic
  `pallas_core.default_mesh_discharge_rule` with the dimension semantics
  forced to a single PARALLEL grid dimension.
  """
  assert isinstance(mesh, TensorCoreMesh)
  if compiler_params and not isinstance(compiler_params, CompilerParams):
    raise ValueError(
        "compiler_params must be a pltpu.CompilerParams"
    )
  if not compiler_params:
    # No params supplied: fall back to the defaults.
    compiler_params = CompilerParams()
  if len(mesh.shape) > 1:
    raise NotImplementedError("Mesh must be 1D")
  if compiler_params.dimension_semantics is not None:
    # The semantics are fixed to (PARALLEL,) below, so a user-provided value
    # would be silently overridden — reject it instead.
    raise ValueError(
        "dimension_semantics must be None for TensorCoreMesh"
    )
  num_cores = len(mesh.devices)
  if num_cores > 1:
    # Since each core will have its own VMEM, we currently disallow VMEM inputs
    # and outputs since other ops might not agree on how they are sharded across
    # cores by the (core-mapped) kernel.
    if any(
        pallas_core.get_memory_space_aval(aval) == MemorySpace.VMEM
        for aval in in_avals
    ):
      raise NotImplementedError(
          "TensorCoreMesh does not support VMEM inputs/outputs when there are"
          " >1 cores. Use HBM or ANY instead."
      )
  return pallas_core.default_mesh_discharge_rule(
      in_avals,
      out_avals,
      *args,
      jaxpr=jaxpr,
      mesh=mesh,
      compiler_params=compiler_params.replace(
          dimension_semantics=(PARALLEL,)
      ),
      debug=debug,
      interpret=interpret,
      cost_estimate=cost_estimate,
      name=name,
      metadata=metadata,
      scratch_shapes=[],
  )

# Register the discharge rule above in the core_map rule table, keyed by mesh
# type.
pallas_core._core_map_mesh_rules[TensorCoreMesh] = (
    _tensorcore_mesh_discharge_rule
)


def _convert_semaphore_type_to_aval(
    out_shape: SemaphoreType,
) -> jax_core.AbstractValue:
  """Converts a SemaphoreType used as an out_shape into its array aval."""
  return out_shape.get_array_aval()


# Let SemaphoreType instances be used directly as Pallas out_shapes.
pallas_core._out_shape_to_aval_mapping[SemaphoreType] = (
    _convert_semaphore_type_to_aval
)


def get_device_kind() -> str:
  """Returns the device kind of the abstract device if one is set, else of the
  default backend device."""
  abstract_device = jax.sharding.get_abstract_mesh().abstract_device
  if abstract_device:
    return abstract_device.device_kind
  return jex_backend.get_default_device().device_kind

def get_num_device_cores() -> int:
  """Returns the core count of the abstract device if one is set, else of the
  default backend device."""
  abstract_device = jax.sharding.get_abstract_mesh().abstract_device
  if abstract_device:
    return abstract_device.num_cores
  return jex_backend.get_default_device().num_cores
